{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Training and Inference Module\n",
    "\n",
    "We modularized commonly used codes for training and inference in the `module` (or `mod` for short) package. This package provides intermediate-level and high-level interface for executing predefined networks. \n",
    "\n",
    "\n",
    "## Basic Usage\n",
    "\n",
    "### Preliminary\n",
    "\n",
    "In this tutorial, we will use a simple multilayer perception for 10 classes and a synthetic dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "from data_iter import SyntheticData\n",
    "\n",
    "# mlp\n",
    "net = mx.sym.Variable('data')\n",
    "net = mx.sym.FullyConnected(net, name='fc1', num_hidden=64)\n",
    "net = mx.sym.Activation(net, name='relu1', act_type=\"relu\")\n",
    "net = mx.sym.FullyConnected(net, name='fc2', num_hidden=10)\n",
    "net = mx.sym.SoftmaxOutput(net, name='softmax')\n",
    "\n",
    "# synthetic 10 classes dataset with 128 dimension \n",
    "data = SyntheticData(10, 128)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Module\n",
    "\n",
    "The most widely used module class is `Module`, which wraps a `Symbol` and one or more `Executor`s.\n",
    "\n",
    "We construct a module by specify\n",
    "\n",
    "- symbol : the network Symbol\n",
    "- context : the device (or a list of devices) for execution\n",
    "- data_names : the list of data variable names \n",
    "- label_names : the list of label variable names\n",
    "\n",
    "One can refer to [data.ipynb](./data.ipynb) for more explanations about the last two arguments. Here we have only one data named `data`, and one label, with the name `softmax_label`, which is automatically named for us following the name `softmax` we specified for the `SoftmaxOutput` operator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "mod = mx.mod.Module(symbol=net, \n",
    "                    context=mx.cpu(),\n",
    "                    data_names=['data'], \n",
    "                    label_names=['softmax_label'])\n",
    "assert len(mod.data_names) == 1 and mod.data_names[0] == 'data', \"Module data bind failed.\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train, Predict, and Evaluate\n",
    "\n",
    "Modules provide high-level APIs for training, predicting and evaluating. To fit a module, simply call the `fit` function with some `DataIters`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Epoch[0] Train-accuracy=0.090625\n",
      "INFO:root:Epoch[0] Time cost=0.019\n",
      "INFO:root:Epoch[0] Validation-accuracy=0.187500\n",
      "INFO:root:Epoch[1] Train-accuracy=0.159375\n",
      "INFO:root:Epoch[1] Time cost=0.018\n",
      "INFO:root:Epoch[1] Validation-accuracy=0.103125\n",
      "INFO:root:Epoch[2] Train-accuracy=0.162500\n",
      "INFO:root:Epoch[2] Time cost=0.022\n",
      "INFO:root:Epoch[2] Validation-accuracy=0.200000\n",
      "INFO:root:Epoch[3] Train-accuracy=0.181250\n",
      "INFO:root:Epoch[3] Time cost=0.022\n",
      "INFO:root:Epoch[3] Validation-accuracy=0.096875\n",
      "INFO:root:Epoch[4] Train-accuracy=0.118750\n",
      "INFO:root:Epoch[4] Time cost=0.022\n",
      "INFO:root:Epoch[4] Validation-accuracy=0.193750\n"
     ]
    }
   ],
   "source": [
    "# Output may vary\n",
    "import logging\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "\n",
    "batch_size=32\n",
    "mod.fit(data.get_iter(batch_size), \n",
    "        eval_data=data.get_iter(batch_size),\n",
    "        optimizer='sgd',\n",
    "        optimizer_params={'learning_rate':0.1},\n",
    "        eval_metric='acc',\n",
    "        num_epoch=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To predict with a module, simply call `predict()` with a `DataIter`. It will collect and return all the prediction results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'shape of predict: (320L, 10L)'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = mod.predict(data.get_iter(batch_size))\n",
    "'shape of predict: %s' % (y.shape,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another convenient API for prediction in the case where the prediction results might be too large to fit in the memory is `iter_predict`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch 0, accuracy 0.156250\n",
      "batch 1, accuracy 0.312500\n",
      "batch 2, accuracy 0.312500\n",
      "batch 3, accuracy 0.187500\n",
      "batch 4, accuracy 0.187500\n",
      "batch 5, accuracy 0.125000\n",
      "batch 6, accuracy 0.250000\n",
      "batch 7, accuracy 0.187500\n",
      "batch 8, accuracy 0.156250\n",
      "batch 9, accuracy 0.187500\n"
     ]
    }
   ],
   "source": [
    "# Output may vary\n",
    "for preds, i_batch, batch in mod.iter_predict(data.get_iter(batch_size)):\n",
    "    pred_label = preds[0].asnumpy().argmax(axis=1)\n",
    "    label = batch.label[0].asnumpy().astype('int32')\n",
    "    print('batch %d, accuracy %f' % (i_batch, float(sum(pred_label==label))/len(label)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we do not need the prediction outputs, but just need to evaluate on a test set, we can call the `score()` function with a `DataIter` and a `EvalMetric`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('mse', 29.322532463073731), ('accuracy', 0.2125)]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Output may vary\n",
    "mod.score(data.get_iter(batch_size), ['mse', 'acc'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save and Load\n",
    "\n",
    "We can save the module parameters in each training epoch by using a checkpoint callback."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Epoch[0] Train-accuracy=0.196875\n",
      "INFO:root:Epoch[0] Time cost=0.019\n",
      "INFO:root:Saved checkpoint to \"mx_mlp-0001.params\"\n",
      "INFO:root:Epoch[1] Train-accuracy=0.159375\n",
      "INFO:root:Epoch[1] Time cost=0.019\n",
      "INFO:root:Saved checkpoint to \"mx_mlp-0002.params\"\n",
      "INFO:root:Epoch[2] Train-accuracy=0.290625\n",
      "INFO:root:Epoch[2] Time cost=0.018\n",
      "INFO:root:Saved checkpoint to \"mx_mlp-0003.params\"\n",
      "INFO:root:Epoch[3] Train-accuracy=0.125000\n",
      "INFO:root:Epoch[3] Time cost=0.019\n",
      "INFO:root:Saved checkpoint to \"mx_mlp-0004.params\"\n",
      "INFO:root:Epoch[4] Train-accuracy=0.190625\n",
      "INFO:root:Epoch[4] Time cost=0.018\n",
      "INFO:root:Saved checkpoint to \"mx_mlp-0005.params\"\n"
     ]
    }
   ],
   "source": [
    "# Output may vary\n",
    "# construct a callback function to save checkpoints\n",
    "model_prefix = 'mx_mlp'\n",
    "checkpoint = mx.callback.do_checkpoint(model_prefix)\n",
    "\n",
    "mod = mx.mod.Module(symbol=net)\n",
    "mod.fit(data.get_iter(batch_size), num_epoch=5, epoch_end_callback=checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To load the saved module parameters, call the `load_checkpoint` function. It load the Symbol and the associated parameters. We can then set the loaded parameters into the module. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)\n",
    "print(sym.tojson() == net.tojson())\n",
    "\n",
    "# assign the loaded parameters to the module\n",
    "mod.set_params(arg_params, aux_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Or if we just want to resume training from a saved checkpoint, instead of calling `set_params()`, we can directly call `fit()`, passing the loaded parameters, so that `fit()` knows to start from those parameters instead of initializing from random. We also set the `begin_epoch` so that so that `fit()` knows we are resuming from a previous saved epoch.\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Epoch[3] Train-accuracy=0.168750\n",
      "INFO:root:Epoch[3] Time cost=0.019\n",
      "INFO:root:Epoch[4] Train-accuracy=0.212500\n",
      "INFO:root:Epoch[4] Time cost=0.019\n"
     ]
    }
   ],
   "source": [
    "# Output may vary\n",
    "mod = mx.mod.Module(symbol=sym)\n",
    "mod.fit(data.get_iter(batch_size),\n",
    "        num_epoch=5,\n",
    "        arg_params=arg_params, \n",
    "        aux_params=aux_params,\n",
    "        begin_epoch=3)"
   ]
  },
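  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The checkpoint file names in the logs above follow the pattern `prefix-NNNN.params`, where `NNNN` is the epoch number counted from 1 and zero-padded to four digits. The file written at the end of `Epoch[2]` is therefore `mx_mlp-0003.params`, which is why loading epoch 3 and passing `begin_epoch=3` continues with `Epoch[3]`. The sketch below reproduces this naming in plain Python; `checkpoint_name` is a hypothetical helper written for illustration and is not part of MXNet.\n",
    "\n",
    "```python\n",
    "def checkpoint_name(prefix, epoch_index):\n",
    "    # epoch_index is the zero-based epoch shown in the training log;\n",
    "    # the saved file is numbered epoch_index + 1, zero-padded to 4 digits\n",
    "    return '%s-%04d.params' % (prefix, epoch_index + 1)\n",
    "\n",
    "# file written at the end of Epoch[2] in the log above\n",
    "print(checkpoint_name('mx_mlp', 2))\n",
    "```"
   ]
  },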
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Module as a computation \"machine\"\n",
    "\n",
    "We already seen how to module for basic training and inference. Now we are going to show a more flexiable usage of module.\n",
    "\n",
    "A module represents a computation component. The design purpose of a module is that it abstract a computation “machine”, that accpets `Symbol` programs and data, and then we can run forward, backward, update parameters, etc. \n",
    "\n",
    "We aim to make the APIs easy and flexible to use, especially in the case when we need to use imperative API to work with multiple modules (e.g. stochastic depth network).\n",
    "\n",
    "A module has several states:\n",
    "- **Initial state**. Memory is not allocated yet, not ready for computation yet.\n",
    "- **Binded**. Shapes for inputs, outputs, and parameters are all known, memory allocated, ready for computation.\n",
    "- **Parameter initialized**. For modules with parameters, doing computation before initializing the parameters might result in undefined outputs.\n",
    "- **Optimizer installed**. An optimizer can be installed to a module. After this, the parameters of the module can be updated according to the optimizer after gradients are computed (forward-backward).\n",
    "\n",
    "The following codes implement a simplified `fit()`. Here we used other components including initializer, optimizer, and metric, which are explained in other notebooks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('accuracy', 0.3875)\n"
     ]
    }
   ],
   "source": [
    "# Output may vary\n",
    "# initial state\n",
    "mod = mx.mod.Module(symbol=net)\n",
    "\n",
    "# bind, tell the module the data and label shapes, so\n",
    "# that memory could be allocated on the devices for computation\n",
    "train_iter = data.get_iter(batch_size)\n",
    "mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)\n",
    "\n",
    "# init parameters\n",
    "mod.init_params(initializer=mx.init.Xavier(magnitude=2.))\n",
    "\n",
    "# init optimizer\n",
    "mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))\n",
    "\n",
    "# use accuracy as the metric\n",
    "metric = mx.metric.create('acc')\n",
    "\n",
    "# train one epoch, i.e. going over the data iter one pass\n",
    "for batch in train_iter:\n",
    "    mod.forward(batch, is_train=True)       # compute predictions\n",
    "    mod.update_metric(metric, batch.label)  # accumulate prediction accuracy\n",
    "    mod.backward()                          # compute gradients\n",
    "    mod.update()                            # update parameters using SGD\n",
    "    \n",
    "# training accuracy\n",
    "print(metric.get())"
   ]
  },
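  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The order of the calls above mirrors the module states listed earlier. As a rough illustration only, the toy class below tracks the same flags in plain Python and rejects calls made out of order. `ToyModule` is a hypothetical stand-in written for this sketch; it performs no computation and is not how MXNet implements `Module`.\n",
    "\n",
    "```python\n",
    "class ToyModule(object):\n",
    "    # hypothetical stand-in that only tracks the state flags\n",
    "    def __init__(self):\n",
    "        self.binded = False\n",
    "        self.params_initialized = False\n",
    "        self.optimizer_initialized = False\n",
    "\n",
    "    def bind(self):\n",
    "        # shapes known, memory allocated\n",
    "        self.binded = True\n",
    "\n",
    "    def init_params(self):\n",
    "        assert self.binded, 'call bind() first'\n",
    "        self.params_initialized = True\n",
    "\n",
    "    def init_optimizer(self):\n",
    "        assert self.binded and self.params_initialized, 'call bind() and init_params() first'\n",
    "        self.optimizer_initialized = True\n",
    "\n",
    "toy = ToyModule()\n",
    "toy.bind()\n",
    "toy.init_params()\n",
    "toy.init_optimizer()\n",
    "print('%s %s %s' % (toy.binded, toy.params_initialized, toy.optimizer_initialized))\n",
    "```"
   ]
  },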
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beside the operations, a module provides a lot of useful information.\n",
    "\n",
    "basic names:\n",
    "- **data_names**: list of string indicating the names of the required data.\n",
    "- **output_names**: list of string indicating the names of the outputs.\n",
    "\n",
    "state information\n",
    "- **binded**: bool, indicating whether the memory buffers needed for computation has been allocated.\n",
    "- **for_training**: whether the module is binded for training (if binded).\n",
    "- **params_initialized**: bool, indicating whether the parameters of this modules has been initialized.\n",
    "- **optimizer_initialized**: bool, indicating whether an optimizer is defined and initialized.\n",
    "- **inputs_need_grad**: bool, indicating whether gradients with respect to the input data is needed. Might be useful when implementing composition of modules.\n",
    "\n",
    "input/output information\n",
    "- **data_shapes**: a list of (name, shape). In theory, since the memory is allocated, we could directly provide the data arrays. But in the case of data parallelization, the data arrays might not be of the same shape as viewed from the external world.\n",
    "- **label_shapes**: a list of (name, shape). This might be [] if the module does not need labels (e.g. it does not contains a loss function at the top), or a module is not binded for training.\n",
    "- **output_shapes**: a list of (name, shape) for outputs of the module.\n",
    "\n",
    "parameters (for modules with parameters)\n",
    "- **get_params()**: return a tuple (arg_params, aux_params). Each of those is a dictionary of name to NDArray mapping. Those NDArray always lives on CPU. The actual parameters used for computing might live on other devices (GPUs), this function will retrieve (a copy of) the latest parameters.\n",
    "- **get_outputs()**: get outputs of the previous forward operation.\n",
    "- **get_input_grads()**: get the gradients with respect to the inputs computed in the previous backward operation.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "([('data', (32, 128))], [('softmax_label', (32,))], [('softmax_output', (32, 10L))])\n",
      "({'fc2_bias': <NDArray 10 @cpu(0)>, 'fc2_weight': <NDArray 10x64 @cpu(0)>, 'fc1_bias': <NDArray 64 @cpu(0)>, 'fc1_weight': <NDArray 64x128 @cpu(0)>}, {})\n"
     ]
    }
   ],
   "source": [
    "print((mod.data_shapes, mod.label_shapes, mod.output_shapes))\n",
    "print(mod.get_params())"
   ]
  },
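  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes printed above follow from the network definition: each `FullyConnected` layer holds a weight of shape `(num_hidden, input_dim)` and a bias of shape `(num_hidden,)`. As a quick sanity check, the plain-Python sketch below counts the parameters from those shapes. The `shapes` dictionary is copied by hand from the output above rather than read from the module, and `num_elements` is a helper introduced only for this example.\n",
    "\n",
    "```python\n",
    "# shapes copied from the get_params() output above\n",
    "shapes = {\n",
    "    'fc1_weight': (64, 128),\n",
    "    'fc1_bias': (64,),\n",
    "    'fc2_weight': (10, 64),\n",
    "    'fc2_bias': (10,),\n",
    "}\n",
    "\n",
    "def num_elements(shape):\n",
    "    # product of all dimensions\n",
    "    count = 1\n",
    "    for dim in shape:\n",
    "        count *= dim\n",
    "    return count\n",
    "\n",
    "total = sum(num_elements(s) for s in shapes.values())\n",
    "print('total parameters: %d' % total)\n",
    "```"
   ]
  },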
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## More on Modules\n",
    "\n",
    "Module simplifies the implementation of new modules. For example\n",
    "\n",
    "- [`SequentialModule`](https://github.com/dmlc/mxnet/blob/master/python/mxnet/module/sequential_module.py) can chain multiple modules together\n",
    "- [`BucketingModule`](https://github.com/dmlc/mxnet/blob/master/python/mxnet/module/bucketing_module.py) is able to handle bucketing, which is useful for various length inputs and outputs\n",
    "- [`PythonModule`](https://github.com/dmlc/mxnet/blob/master/python/mxnet/module/python_module.py) implements many APIs as empty function to ease users to implement customized modules. \n",
    "\n",
    "See also [example/module](https://github.com/dmlc/mxnet/tree/master/example/module) for a list of code examples using the module API.\n",
    "\n",
    "## Implementation\n",
    "\n",
    "The `module` is implemented in python, located at [python/mxnet/module](https://github.com/dmlc/mxnet/tree/master/python/mxnet/module) \n",
    "\n",
    "## Futher Readings\n",
    "\n",
    "- [module API doc](http://mxnet.io/packages/python/module.html#module-interface-api)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
